#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include "type_shim.h"

// cublasGemmEx overload for at::BFloat16 buffers.
//
// A, B and C are all handed to cuBLAS as CUDA_R_16BF, i.e. the GEMM result is combined with
// (and stored back into) a BF16 `C`; the compute-type argument passed to cuBLAS is CUDA_R_32F.
// `alpha` and `beta` are host-side float scalars forwarded unchanged, and the algorithm hint is
// CUBLAS_GEMM_DEFAULT_TENSOR_OP. The returned cublasStatus_t is checked by TORCH_CUDABLAS_CHECK.
//
// NOTE: despite the `_fp16` suffix this overload is the BF16 one; it shares its name with the
// at::Half overload so that callers can select between them by argument type.
void gemmex_wrapper_fp16(cublasHandle_t handle,
                         cublasOperation_t transa,
                         cublasOperation_t transb,
                         int m,
                         int n,
                         int k,
                         const float *alpha,
                         at::BFloat16 *A,
                         int lda,
                         at::BFloat16 *B,
                         int ldb,
                         const float *beta,
                         at::BFloat16 *C,
                         int ldc) {
  // Keep the call expression token-for-token as is: the check macro may embed its text in the
  // error message, so only the line breaks (placed after commas) differ.
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, transa, transb,
                                    m, n, k,
                                    alpha,
                                    A, CUDA_R_16BF, lda,
                                    B, CUDA_R_16BF, ldb,
                                    beta,
                                    C, CUDA_R_16BF, ldc,
                                    CUDA_R_32F,
                                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

// cublasGemmEx overload for at::Half buffers.
//
// A, B and C are all handed to cuBLAS as CUDA_R_16F, i.e. the GEMM result is combined with
// (and stored back into) an FP16 `C`; the compute-type argument passed to cuBLAS is CUDA_R_32F.
// `alpha` and `beta` are host-side float scalars forwarded unchanged, and the algorithm hint is
// CUBLAS_GEMM_DEFAULT_TENSOR_OP. The returned cublasStatus_t is checked by TORCH_CUDABLAS_CHECK.
void gemmex_wrapper_fp16(cublasHandle_t handle,
                         cublasOperation_t transa,
                         cublasOperation_t transb,
                         int m,
                         int n,
                         int k,
                         const float *alpha,
                         at::Half *A,
                         int lda,
                         at::Half *B,
                         int ldb,
                         const float *beta,
                         at::Half *C,
                         int ldc) {
  // Keep the call expression token-for-token as is: the check macro may embed its text in the
  // error message, so only the line breaks (placed after commas) differ.
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, transa, transb,
                                    m, n, k,
                                    alpha,
                                    A, CUDA_R_16F, lda,
                                    B, CUDA_R_16F, ldb,
                                    beta,
                                    C, CUDA_R_16F, ldc,
                                    CUDA_R_32F,
                                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

// Accumulates a weight gradient into `d_weight` with one GEMM on the current PyTorch cuBLAS handle.
//
// In cuBLAS (column-major) terms the call computes
//     d_weight = 1 * input * d_output^T + 1 * d_weight
// with
//     input    : in_dim  x hidden_dim, leading dimension in_dim   (CUBLAS_OP_N)
//     d_output : out_dim x hidden_dim, leading dimension out_dim  (CUBLAS_OP_T)
//     d_weight : in_dim  x out_dim,    leading dimension in_dim   (read and overwritten)
// Read as densely packed row-major buffers, that is
//     d_weight[out_dim][in_dim] += d_output[hidden_dim][out_dim]^T * input[hidden_dim][in_dim].
//
// Because beta is 1, the existing contents of `d_weight` are kept and the product is added to them.
// No bounds or layout validation happens here; the pointers must reference buffers of at least the
// sizes listed above. T must be a type for which a gemmex_wrapper_fp16 overload exists.
template <typename T>
void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight, int in_dim, int hidden_dim, int out_dim) {
  // The handle is used as returned; the previous, unused stream query (whose status was ignored)
  // has been dropped.
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();

  // Host-side scalars for cublasGemmEx; alpha = beta = 1 turns the GEMM into an accumulation.
  const float alpha = 1.0f;
  const float beta = 1.0f;

  gemmex_wrapper_fp16(handle, CUBLAS_OP_N, CUBLAS_OP_T, in_dim, out_dim, hidden_dim, &alpha, input, in_dim, d_output,
                      out_dim, &beta, d_weight, in_dim);
}

// Explicit instantiations of wgrad_gemm_accum_fp16_cuda for the two element types that have a
// gemmex_wrapper_fp16 overload in this file (at::Half and at::BFloat16), so the code for both is
// emitted in this translation unit.
template void wgrad_gemm_accum_fp16_cuda<at::Half>(at::Half *input, at::Half *d_output, at::Half *d_weight, int in_dim,
                                                   int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp16_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output,
                                                       at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim);

// Tensor-level entry point: accumulates a dense layer's weight gradient into `d_weight` in place.
//
// `input` and `d_output` may carry more than two dimensions; every dimension except the last is
// folded into a single row dimension, after which the tensors are treated as
//     input    : [hidden_dim, in_dim]
//     d_output : [hidden_dim, out_dim]
//     d_weight : [out_dim,    in_dim]
// (`hidden_dim` is this file's name for the folded row count.) The element type is taken from
// `input` and dispatched through DISPATCH_HALF_AND_BFLOAT.
//
// The GEMM below is issued with leading dimensions equal to `in_dim` / `out_dim` and raw data
// pointers, so the arguments are validated first: all three tensors must be CUDA tensors on one
// device, share a dtype, be contiguous, and have mutually consistent shapes. A violated
// precondition raises through TORCH_CHECK instead of letting cuBLAS read or write out of bounds.
void wgrad_gemm_accum_fp16_cuda_stub(at::Tensor &input, at::Tensor &d_output, at::Tensor &d_weight) {
  TORCH_CHECK(input.is_cuda() && d_output.is_cuda() && d_weight.is_cuda(),
              "wgrad_gemm_accum_fp16: input, d_output and d_weight must all be CUDA tensors");
  TORCH_CHECK(input.device() == d_output.device() && input.device() == d_weight.device(),
              "wgrad_gemm_accum_fp16: input, d_output and d_weight must be on the same device");
  TORCH_CHECK(d_output.scalar_type() == input.scalar_type() && d_weight.scalar_type() == input.scalar_type(),
              "wgrad_gemm_accum_fp16: input, d_output and d_weight must share one dtype, got ", input.scalar_type(),
              ", ", d_output.scalar_type(), " and ", d_weight.scalar_type());
  TORCH_CHECK(input.dim() >= 2, "wgrad_gemm_accum_fp16: input must have at least 2 dimensions, got ", input.dim());
  TORCH_CHECK(d_output.dim() >= 2, "wgrad_gemm_accum_fp16: d_output must have at least 2 dimensions, got ",
              d_output.dim());
  TORCH_CHECK(d_weight.dim() == 2, "wgrad_gemm_accum_fp16: d_weight must be 2-dimensional, got ", d_weight.dim());

  at::Tensor input_2d, d_output_2d;
  // input tensor: fold all leading dimensions into the first one
  auto in_sizes = input.sizes();
  if (input.dim() > 2) {
    input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
  } else {
    input_2d = input;
  }
  // d_output tensor: fold all leading dimensions into the first one
  auto d_out_sizes = d_output.sizes();
  if (d_output.dim() > 2) {
    d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
  } else {
    d_output_2d = d_output;
  }

  // The GEMM assumes densely packed rows (lda == in_dim, ldb == out_dim, ldc == in_dim).
  TORCH_CHECK(input_2d.is_contiguous(), "wgrad_gemm_accum_fp16: input must be contiguous");
  TORCH_CHECK(d_output_2d.is_contiguous(), "wgrad_gemm_accum_fp16: d_output must be contiguous");
  TORCH_CHECK(d_weight.is_contiguous(), "wgrad_gemm_accum_fp16: d_weight must be contiguous");

  const int64_t rows = input_2d.size(0);
  const int64_t in_features = input_2d.size(1);
  const int64_t out_features = d_weight.size(0);

  TORCH_CHECK(d_output_2d.size(0) == rows, "wgrad_gemm_accum_fp16: input and d_output disagree on the row count (",
              rows, " vs ", d_output_2d.size(0), ")");
  TORCH_CHECK(d_output_2d.size(1) == out_features, "wgrad_gemm_accum_fp16: d_output has ", d_output_2d.size(1),
              " columns but d_weight has ", out_features, " rows");
  TORCH_CHECK(d_weight.size(1) == in_features, "wgrad_gemm_accum_fp16: d_weight has ", d_weight.size(1),
              " columns but input has ", in_features);

  // The cuBLAS wrapper takes `int` extents; refuse sizes that would be truncated by the narrowing.
  const int64_t max_extent = static_cast<int64_t>(std::numeric_limits<int>::max());
  TORCH_CHECK(rows <= max_extent && in_features <= max_extent && out_features <= max_extent,
              "wgrad_gemm_accum_fp16: tensor dimensions exceed the range supported by cuBLAS (", rows, ", ",
              in_features, ", ", out_features, ")");

  const int hidden_dim = static_cast<int>(rows);
  const int in_dim = static_cast<int>(in_features);
  const int out_dim = static_cast<int>(out_features);

  // With any extent equal to zero the update is a no-op (d_weight += 0 or d_weight is empty), so
  // the GEMM is skipped; the dtype dispatch still runs so unsupported dtypes are reported.
  const bool nothing_to_do = hidden_dim == 0 || in_dim == 0 || out_dim == 0;

  DISPATCH_HALF_AND_BFLOAT(
      input_2d.scalar_type(), "wgrad_gemm_accum_fp16",
      if (!nothing_to_do) {
        wgrad_gemm_accum_fp16_cuda<scalar_t>(input_2d.data_ptr<scalar_t>(), d_output_2d.data_ptr<scalar_t>(),
                                             d_weight.data_ptr<scalar_t>(), in_dim, hidden_dim, out_dim);
      });
}
